{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:13:17.875225Z",
     "start_time": "2020-07-08T08:13:15.547995Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.5.1'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import random\n",
    "import time\n",
    "import heapq\n",
    "import copy\n",
    "import gc\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "tqdm.pandas()\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.model_selection import KFold,StratifiedKFold\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pack_sequence, pad_packed_sequence, pad_sequence\n",
    "from torch.utils.data import DataLoader, Dataset, SequentialSampler\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'\n",
    "torch.distributed.init_process_group(backend=\"nccl\", init_method='tcp://localhost:23456', rank=0, world_size=1)\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:13:17.894416Z",
     "start_time": "2020-07-08T08:13:17.877007Z"
    }
   },
   "outputs": [],
   "source": [
    "# set random seeds to keep the results identical\n",
    "def setup_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "def worker_init_fn(worker_id):\n",
    "    setup_seed(GLOBAL_SEED)\n",
    "    \n",
    "GLOBAL_SEED = 2020\n",
    "setup_seed(GLOBAL_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:13:17.913126Z",
     "start_time": "2020-07-08T08:13:17.897885Z"
    }
   },
   "outputs": [],
   "source": [
    "data_path = './processed_data/'\n",
    "model_save = './model_save/'\n",
    "embedding_path = './embedding/'\n",
    "res_path = './result/'\n",
    "if not os.path.exists(model_save):\n",
    "    os.makedirs(model_save)\n",
    "if not os.path.exists(res_path):\n",
    "    os.makedirs(res_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:14:14.655346Z",
     "start_time": "2020-07-08T08:13:17.914476Z"
    }
   },
   "outputs": [],
   "source": [
    "df = pd.read_pickle(os.path.join(data_path, 'processed_data_numerical.pkl'))\n",
    "df['age'] = df['age'] - 1\n",
    "df['gender'] = df['gender'] - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:14:14.686446Z",
     "start_time": "2020-07-08T08:14:14.657103Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>time</th>\n",
       "      <th>creative_id</th>\n",
       "      <th>click_times</th>\n",
       "      <th>ad_id</th>\n",
       "      <th>product_id</th>\n",
       "      <th>product_category</th>\n",
       "      <th>advertiser_id</th>\n",
       "      <th>industry</th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>user_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[20, 20, 20, 39, 40, 43, 46, 52, 60, 64, 64, 7...</td>\n",
       "      <td>[877468, 209778, 821396, 1683713, 122032, 7169...</td>\n",
       "      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2]</td>\n",
       "      <td>[773445, 188507, 724607, 1458878, 109959, 6621...</td>\n",
       "      <td>[44315, 136, 44315, 44315, 1334, 44315, 44315,...</td>\n",
       "      <td>[5, 2, 5, 5, 2, 18, 5, 5, 18, 2, 2, 2, 2]</td>\n",
       "      <td>[29455, 9702, 7293, 14668, 11411, 14681, 17189...</td>\n",
       "      <td>[106, 6, 326, 326, 336, 326, 73, 217, 64, 245,...</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                      time  \\\n",
       "user_id                                                      \n",
       "1        [20, 20, 20, 39, 40, 43, 46, 52, 60, 64, 64, 7...   \n",
       "\n",
       "                                               creative_id  \\\n",
       "user_id                                                      \n",
       "1        [877468, 209778, 821396, 1683713, 122032, 7169...   \n",
       "\n",
       "                                     click_times  \\\n",
       "user_id                                            \n",
       "1        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2]   \n",
       "\n",
       "                                                     ad_id  \\\n",
       "user_id                                                      \n",
       "1        [773445, 188507, 724607, 1458878, 109959, 6621...   \n",
       "\n",
       "                                                product_id  \\\n",
       "user_id                                                      \n",
       "1        [44315, 136, 44315, 44315, 1334, 44315, 44315,...   \n",
       "\n",
       "                                  product_category  \\\n",
       "user_id                                              \n",
       "1        [5, 2, 5, 5, 2, 18, 5, 5, 18, 2, 2, 2, 2]   \n",
       "\n",
       "                                             advertiser_id  \\\n",
       "user_id                                                      \n",
       "1        [29455, 9702, 7293, 14668, 11411, 14681, 17189...   \n",
       "\n",
       "                                                  industry  age  gender  \n",
       "user_id                                                                  \n",
       "1        [106, 6, 326, 326, 336, 326, 73, 217, 64, 245,...  3.0     0.0  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取预训练好的Word Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:14:14.738388Z",
     "start_time": "2020-07-08T08:14:14.723360Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['fasttext',\n",
       " 'embedding_w2v_sg1_hs0_win10_size300.npz',\n",
       " '.ipynb_checkpoints',\n",
       " 'word2vec',\n",
       " 'glove',\n",
       " '.empty',\n",
       " 'embedding_w2v_sg1_hs0_win10_size128.npz']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.listdir(embedding_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:15:01.289575Z",
     "start_time": "2020-07-08T08:14:14.803784Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "47"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding = np.load(os.path.join(embedding_path, 'embedding_w2v_sg1_hs0_win10_size300.npz'))\n",
    "creative = embedding['creative_w2v']\n",
    "ad= embedding['ad_w2v']\n",
    "advertiser = embedding['advertiser_w2v']\n",
    "product = embedding['product_w2v']\n",
    "industry = embedding['industry_w2v']\n",
    "product_cate = embedding['product_cate_w2v']\n",
    "del embedding\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 需要使用的embedding特征以及对应的序列编号"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:20:15.556911Z",
     "start_time": "2020-07-08T08:15:01.291332Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4000000/4000000 [08:07<00:00, 8211.46it/s]\n"
     ]
    }
   ],
   "source": [
    "# 这里将需要使用到的特征列直接拼接成一个向量，后面直接split即可\n",
    "data_seq = df[['creative_id', 'ad_id', 'advertiser_id', 'product_id', 'industry', 'click_times']].progress_apply(lambda s: np.hstack(s.values), axis=1).values\n",
    "\n",
    "# embedding_list = [creative_embed, ad_embed, advertiser_embed, product_embed]\n",
    "# embedding_list = [creative_glove, ad_glove, advertiser_glove, product_glove]\n",
    "embedding_list = [creative, ad, advertiser, product, industry]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立PyTorch Dataset 和 Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:20:15.580940Z",
     "start_time": "2020-07-08T08:20:15.559053Z"
    }
   },
   "outputs": [],
   "source": [
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, seqs, labels, input_num, shuffle=False):\n",
    "        self.seqs = seqs\n",
    "        self.labels = labels\n",
    "        self.input_num = input_num\n",
    "        self.shuffle = shuffle\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.seqs)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        length = int(self.seqs[idx].shape[0]/self.input_num)\n",
    "        seq_list = list(torch.LongTensor(self.seqs[idx]).split(length, dim=0))          \n",
    "        label = torch.LongTensor(self.labels[idx])\n",
    "        # 对数据进行随机shuffle\n",
    "        if self.shuffle and torch.rand(1) < 0.5:\n",
    "            random_pos = torch.randperm(length)\n",
    "            for i in range(len(seq_list)):\n",
    "                seq_list[i] = seq_list[i][random_pos]\n",
    "        return seq_list + [length, label]\n",
    "\n",
    "    \n",
    "def pad_truncate(Batch):\n",
    "    *seqs, lengths, labels = list(zip(*Batch))\n",
    "    # 长度截取到99%的大小，可以缩短pad长度，大大节省显存\n",
    "    trun_len = torch.topk(torch.tensor(lengths), max(int(0.01*len(lengths)), 1))[0][-1]\n",
    "    # 保险起见，再设置一个最大长度\n",
    "    max_len = min(trun_len, 150)\n",
    "    seq_list = list(pad_sequence(seq, batch_first=True)[:, :max_len] for seq in seqs)\n",
    "    return seq_list, torch.tensor(lengths).clamp_max(max_len), torch.stack(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:20:22.494750Z",
     "start_time": "2020-07-08T08:20:15.582588Z"
    }
   },
   "outputs": [],
   "source": [
    "input_num = 6\n",
    "BATCH_SIZE_TRAIN = 1024\n",
    "BATCH_SIZE_VAL = 2048\n",
    "BATCH_SIZE_TEST = 2048\n",
    "kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)\n",
    "data_folds = []\n",
    "valid_indexs = [] # 用于后面保存五折的验证集结果时，按照1到900000对应顺序\n",
    "for idx, (train_index, valid_index) in enumerate(kf.split(X=df.iloc[:3000000], y=df.iloc[:3000000]['age'])):\n",
    "    valid_indexs.append(valid_index)\n",
    "    X_train, X_val, X_test = data_seq[train_index], data_seq[valid_index], data_seq[3000000:]\n",
    "    y_train, y_val =  np.array(df.iloc[train_index, -2:]), np.array(df.iloc[valid_index, -2:])\n",
    "    y_test = np.random.rand(X_test.shape[0], 2)\n",
    "    \n",
    "    train_dataset = CustomDataset(X_train, y_train, input_num, shuffle=True)\n",
    "    val_dataset = CustomDataset(X_val, y_val, input_num, shuffle=False)\n",
    "    test_dataset = CustomDataset(X_test, y_test, input_num, shuffle=False)\n",
    "\n",
    "    train_dataloader = DataLoader(train_dataset, \n",
    "                                  batch_size=BATCH_SIZE_TRAIN, \n",
    "                                  shuffle=True, \n",
    "                                  collate_fn=pad_truncate, \n",
    "                                  num_workers=0, \n",
    "                                  worker_init_fn=worker_init_fn)\n",
    "    \n",
    "    valid_dataloader = DataLoader(val_dataset, \n",
    "                                  batch_size=BATCH_SIZE_VAL, \n",
    "                                  sampler=SequentialSampler(val_dataset), \n",
    "                                  shuffle=False, \n",
    "                                  collate_fn=pad_truncate, \n",
    "                                  num_workers=0, \n",
    "                                  worker_init_fn=worker_init_fn)\n",
    "    \n",
    "    test_dataloader = DataLoader(test_dataset, \n",
    "                                 batch_size=BATCH_SIZE_TEST, \n",
    "                                 sampler=SequentialSampler(test_dataset), \n",
    "                                 shuffle=False, \n",
    "                                 collate_fn=pad_truncate, \n",
    "                                 num_workers=0, \n",
    "                                 worker_init_fn=worker_init_fn)\n",
    "    data_folds.append((train_dataloader, valid_dataloader, test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:20:22.576290Z",
     "start_time": "2020-07-08T08:20:22.496809Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "del data_seq, creative, ad, advertiser, product, industry, product_cate\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:20:22.605152Z",
     "start_time": "2020-07-08T08:20:22.578300Z"
    }
   },
   "outputs": [],
   "source": [
    "class BiLSTM(nn.Module):\n",
    "    def __init__(self, embedding_list, embedding_freeze, lstm_size, fc1, fc2, num_layers=1, rnn_dropout=0.2, embedding_dropout=0.2, fc_dropout=0.2):\n",
    "        super().__init__()\n",
    "        self.embedding_layers = nn.ModuleList([nn.Embedding.from_pretrained(torch.HalfTensor(embedding).cuda(), freeze=freeze) for embedding, freeze in zip(embedding_list, embedding_freeze)])\n",
    "        self.input_dim = int(np.sum([embedding.shape[1] for embedding in embedding_list]))\n",
    "        self.lstm = nn.LSTM(input_size = self.input_dim, \n",
    "                                      hidden_size = lstm_size, \n",
    "                                      num_layers = num_layers,\n",
    "                                      bidirectional = True, \n",
    "                                      batch_first = True, \n",
    "                                      dropout = rnn_dropout) \n",
    "                                                  \n",
    "        \n",
    "        \n",
    "        self.fc1 = nn.Linear(2*lstm_size, fc1)\n",
    "        self.fc2 = nn.Linear(fc1, fc2)\n",
    "        self.fc3 = nn.Linear(fc2, 12)\n",
    "        \n",
    "        self.rnn_dropout = nn.Dropout(rnn_dropout)\n",
    "        self.embedding_dropout = nn.Dropout(embedding_dropout)\n",
    "        self.fc_dropout = nn.Dropout(fc_dropout)\n",
    "  \n",
    "    \n",
    "    def forward(self, seq_list, lengths):\n",
    "        batch_size, total_length= seq_list[0].size()\n",
    "        lstm_outputs = []\n",
    "        click_time = seq_list[-1]\n",
    "        embeddings = []\n",
    "        for idx, seq in enumerate(seq_list[:-1]):\n",
    "            embedding = self.embedding_layers[idx](seq).to(torch.float32)\n",
    "            embedding = self.embedding_dropout(embedding)\n",
    "            embeddings.append(embedding)\n",
    "        packed = pack_padded_sequence(torch.cat(embeddings, dim=-1), lengths, batch_first=True, enforce_sorted=False)\n",
    "        packed_output, (h_n, c_n) = self.lstm(packed)\n",
    "        lstm_output, _ = pad_packed_sequence(packed_output, batch_first=True, total_length=total_length, padding_value=-float('inf'))\n",
    "        lstm_output = self.rnn_dropout(lstm_output)\n",
    "        # lstm_output shape: (batchsize, total_length, 2*lstm_size)\n",
    "        max_output = F.max_pool2d(lstm_output, (total_length, 1), stride=(1, 1)).squeeze()\n",
    "        # output shape: (batchsize, 2*lstm_size)\n",
    "        fc_out = F.relu(self.fc1(max_output))\n",
    "        fc_out = self.fc_dropout(fc_out)\n",
    "        fc_out = F.relu(self.fc2(fc_out))\n",
    "        pred = self.fc3(fc_out)\n",
    "        age_pred = pred[:, :10]\n",
    "        gender_pred = pred[:, -2:]\n",
    "        return age_pred, gender_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-08T08:20:22.643822Z",
     "start_time": "2020-07-08T08:20:22.606860Z"
    }
   },
   "outputs": [],
   "source": [
    "def validate(model, val_dataloader, criterion, history, n_iters):\n",
    "    model.eval()\n",
    "    global best_acc, best_model, validate_history\n",
    "    costs = []\n",
    "    age_accs = []\n",
    "    gender_accs = []\n",
    "    with torch.no_grad():\n",
    "        for idx, batch in enumerate(val_dataloader):\n",
    "            seq_list, lengths, labels = batch\n",
    "            seq_list_device = [seq.cuda() for seq in seq_list]\n",
    "            lengths_device = lengths.cuda()\n",
    "            labels = labels.cuda()\n",
    "            age_output, gender_output = model(seq_list_device, lengths_device)    \n",
    "            loss = criterion(age_output, gender_output, labels)\n",
    "            costs.append(loss.item())\n",
    "            _, age_preds = torch.max(age_output, 1)\n",
    "            _, gender_preds = torch.max(gender_output, 1)\n",
    "            age_accs.append((age_preds == labels[:, 0]).float().mean().item())\n",
    "            gender_accs.append((gender_preds == labels[:, 1]).float().mean().item())\n",
    "            torch.cuda.empty_cache()\n",
    "    mean_accs = np.mean(age_accs) + np.mean(gender_accs)\n",
    "    mean_costs = np.mean(costs)\n",
    "    writer.add_scalar('gender/validate_accuracy', np.mean(gender_accs), n_iters)\n",
    "    writer.add_scalar('gender/validate_loss', mean_costs, n_iters)\n",
    "    writer.add_scalar('age/validate_accuracy',np.mean(age_accs), n_iters)\n",
    "    writer.add_scalar('age/validate_loss', mean_costs, n_iters)\n",
    "    if mean_accs > history['best_model'][0][0]:  \n",
    "        save_dict = copy.deepcopy(model.state_dict())\n",
    "        embedding_keys = []\n",
    "        for key in save_dict.keys():\n",
    "            if key.startswith('embedding'):\n",
    "                embedding_keys.append(key)\n",
    "        for key in embedding_keys:\n",
    "            save_dict.pop(key)\n",
    "        heapq.heapify(history['best_model'])\n",
    "        checkpoint_pth = history['best_model'][0][1]\n",
    "        heapq.heappushpop(history['best_model'], (mean_accs, checkpoint_pth))\n",
    "        torch.save(save_dict, checkpoint_pth)\n",
    "        del save_dict\n",
    "        gc.collect()\n",
    "        torch.cuda.empty_cache()\n",
    "    return mean_costs, mean_accs\n",
    "\n",
    "\n",
    "def train(model, train_dataloader, val_dataloader, criterion, optimizer, epoch, history, validate_points, scheduler, step=True):\n",
    "    model.train()\n",
    "    costs = []\n",
    "    age_accs = []\n",
    "    gender_accs = []\n",
    "    val_loss, val_acc = 0, 0\n",
    "    with tqdm(total=len(train_dataloader.dataset), desc='Epoch{}'.format(epoch)) as pbar:\n",
    "        for idx, batch in enumerate(train_dataloader):\n",
    "            seq_list, lengths, labels = batch\n",
    "            seq_list_device = [seq.cuda() for seq in seq_list]\n",
    "            lengths_device = lengths.cuda()\n",
    "            labels = labels.cuda()\n",
    "            age_output, gender_output = model(seq_list_device, lengths_device)    \n",
    "            loss = criterion(age_output, gender_output, labels)\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if step:\n",
    "                scheduler.step()\n",
    "            with torch.no_grad():\n",
    "                costs.append(loss.item())\n",
    "                _, age_preds = torch.max(age_output, 1)\n",
    "                _, gender_preds = torch.max(gender_output, 1)\n",
    "                age_accs.append((age_preds == labels[:, 0]).float().mean().item())\n",
    "                gender_accs.append((gender_preds == labels[:, 1]).float().mean().item())\n",
    "                pbar.update(labels.size(0))\n",
    "            n_iters = idx + len(train_dataloader)*(epoch-1)\n",
    "            if idx in validate_points:\n",
    "                val_loss, val_acc = validate(model, val_dataloader, criterion, history, n_iters)\n",
    "                model.train()\n",
    "            \n",
    "            writer.add_scalar('gender/train_accuracy', gender_accs[-1], n_iters)\n",
    "            writer.add_scalar('gender/train_loss', costs[-1], n_iters)\n",
    "            writer.add_scalar('age/train_accuracy', age_accs[-1], n_iters)\n",
    "            writer.add_scalar('age/train_loss', costs[-1], n_iters)\n",
    "            writer.add_scalar('age/learning_rate', scheduler.get_lr()[0], n_iters)\n",
    "            pbar.set_postfix_str('loss:{:.4f}, acc:{:.4f}, val-loss:{:.4f}, val-acc:{:.4f}'.format(np.mean(costs[-10:]), np.mean(age_accs[-10:])+np.mean(gender_accs[-10:]), val_loss, val_acc))\n",
    "            gc.collect()\n",
    "\n",
    "    \n",
    "def test(oof_train_test, model, test_dataloader, val_dataloader, valid_index, weight=1):\n",
    "    # 这里测试的时候对验证集也进行计算，以便于后续模型融合和search weight等提高\n",
    "    model.eval()\n",
    "    y_val = []\n",
    "    age_pred = []\n",
    "    gender_pred = []\n",
    "    age_pred_val = []\n",
    "    gender_pred_val = []\n",
    "    with torch.no_grad():\n",
    "        for idx, batch in enumerate(test_dataloader):\n",
    "            seq_list, lengths, labels = batch\n",
    "            seq_list_device = [seq.cuda() for seq in seq_list]\n",
    "            lengths_device = lengths.cuda()\n",
    "            age_output, gender_output = model(seq_list_device, lengths_device)    \n",
    "            age_pred.append(age_output.cpu())\n",
    "            gender_pred.append(gender_output.cpu())\n",
    "            torch.cuda.empty_cache()\n",
    "            \n",
    "        for idx, batch in enumerate(val_dataloader):\n",
    "            seq_list, lengths, labels = batch\n",
    "            seq_list_device = [seq.cuda() for seq in seq_list]\n",
    "            lengths_device = lengths.cuda()\n",
    "            age_output, gender_output = model(seq_list_device, lengths_device)\n",
    "            age_pred_val.append(age_output.cpu())\n",
    "            gender_pred_val.append(gender_output.cpu())\n",
    "            y_val.append(labels)\n",
    "            torch.cuda.empty_cache()\n",
    "            \n",
    "    # 0到9列存储age的预测概率分布，10列到11列存储gender的预测概率分布，12、13列分别存储age和gender的真实标签        \n",
    "    oof_train_test[valid_index, :10] += F.softmax(torch.cat(age_pred_val)).numpy() * weight\n",
    "    oof_train_test[valid_index, 10:12] += F.softmax(torch.cat(gender_pred_val)).numpy() * weight\n",
    "    oof_train_test[valid_index, 12:] = torch.cat(y_val).numpy()\n",
    "    oof_train_test[3000000:, :10] += F.softmax(torch.cat(age_pred)).numpy() * (1/5) * weight\n",
    "    oof_train_test[3000000:, 10:12] += F.softmax(torch.cat(gender_pred)).numpy() * (1/5) * weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-07-08T08:14:38.422Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch1: 100%|██████████| 2400000/2400000 [50:50<00:00, 786.68it/s, loss:0.7966, acc:1.4336, val-loss:0.7819, val-acc:1.4467]  \n",
      "Epoch2: 100%|██████████| 2400000/2400000 [50:54<00:00, 785.77it/s, loss:0.7723, acc:1.4552, val-loss:0.7679, val-acc:1.4556]  \n",
      "Epoch3: 100%|██████████| 2400000/2400000 [50:58<00:00, 784.80it/s, loss:0.7672, acc:1.4595, val-loss:0.7620, val-acc:1.4593]  \n",
      "Epoch4:  11%|█         | 258048/2400000 [04:28<37:09, 960.83it/s, loss:0.7468, acc:1.4634, val-loss:0.0000, val-acc:0.0000] "
     ]
    }
   ],
   "source": [
    "# 定义联合损失函数\n",
    "def criterion(age_output, gender_output, labels):\n",
    "    age_loss = nn.CrossEntropyLoss()(age_output, labels[:, 0])\n",
    "    gender_loss = nn.CrossEntropyLoss()(gender_output, labels[:, 1])\n",
    "    return age_loss*0.6 + gender_loss*0.4\n",
    "\n",
    "# 0到9列存储age的预测概率分布，10列到11列存储gender的预测概率分布，12、13列分别存储age和gender的真实标签\n",
    "oof_train_test = np.zeros((4000000, 14))\n",
    "# oof_train_test = np.load(os.path.join(model_save, \"lstm_v2_300size_fold_2.npy\"))\n",
    "\n",
    "acc_folds = []\n",
    "model_name = 'lstm_v2_300size_win10_dropout'\n",
    "best_checkpoint_num = 3\n",
    "for idx, (train_dataloader, val_dataloader, test_dataloader) in enumerate(data_folds):\n",
    "#     if idx in [0]:\n",
    "#         continue\n",
    "    history = {'best_model': []}\n",
    "    for i in range(best_checkpoint_num):\n",
    "        history['best_model'].append((0, os.path.join(model_save, '{}_checkpoint_{}.pth'.format(model_name, i))))\n",
    "    # 对应顺序: creative_w2v, ad_w2v, advertiser_w2v, product_w2v, industry_w2v\n",
    "    embedding_freeze = [True, True, True, True, True]\n",
    "    validate_points = list(np.linspace(0, len(train_dataloader)-1, 2).astype(int))[1:]\n",
    "    model = BiLSTM(embedding_list, embedding_freeze, lstm_size=1500, fc1=1500, fc2=800,  num_layers=2, rnn_dropout=0.1, fc_dropout=0.2, embedding_dropout=0.2)  \n",
    "    model = model.cuda()\n",
    "    model = nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.999), lr=1e-3)\n",
    "    epochs = 7\n",
    "#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)\n",
    "    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=2e-3, step_size_up=int(len(train_dataloader)/2), cycle_momentum=False, mode='triangular')\n",
    "#     scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=3e-3, epochs=epochs, steps_per_epoch=len(train_dataloader), pct_start=0.2, anneal_strategy='linear', div_factor=30, final_div_factor=1e4)\n",
    "    for epoch in range(1, epochs+1):\n",
    "        writer = SummaryWriter(log_dir='./runs/{}_fold_{}'.format(model_name, idx))\n",
    "        train(model, train_dataloader, val_dataloader, criterion, optimizer, epoch, history, validate_points, scheduler, step=True)\n",
    "#         scheduler.step()\n",
    "        gc.collect()\n",
    "    for (acc, checkpoint_pth), weight in zip(sorted(history['best_model'], reverse=True), [0.5, 0.3, 0.2]):\n",
    "        model.load_state_dict(torch.load(checkpoint_pth, map_location=torch.device('cpu')), strict=False)\n",
    "        test(oof_train_test, model, test_dataloader, val_dataloader, valid_indexs[idx], weight=weight)\n",
    "    acc_folds.append(sorted(history['best_model'], reverse=True)[0][0])\n",
    "    np.save(os.path.join(model_save, \"{}_fold_{}.npy\".format(model_name, idx)), oof_train_test)\n",
    "    del model\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-06T00:21:57.241854Z",
     "start_time": "2020-07-06T00:21:57.208328Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1.4644472867555587,\n",
       " 1.4632941303399642,\n",
       " 1.4649451368904765,\n",
       " 1.4653746609801725,\n",
       " 1.4639947543575496]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc_folds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-07-06T00:23:23.091488Z",
     "start_time": "2020-07-06T00:23:22.762955Z"
    }
   },
   "outputs": [],
   "source": [
    "np.save(os.path.join(res_path, \"{}_5folds_{:.4f}.npy\".format(model_name, np.mean(acc_folds))), oof_train_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_py3",
   "language": "python",
   "name": "conda_pytorch_py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
